# Bootstrap LASSO / post-selection OLS script. For one outcome, period unit and
# temporal transformation (all taken from the command line), it fits
# cv.glmnet and fixest::feols models on bootstrap resamples within each period
# and saves the coefficients.
# Start from an empty workspace and reclaim memory before loading large data.
rm(list = ls())
gc()
# NOTE(review): require() only warns and returns FALSE when a package is
# missing, so failures surface later as "could not find function" errors.
require(tidyverse)
require(fixest)
require(lubridate)
require(caret)
require(glmnet)
require(glmnetUtils)
require(fastDummies)

# Cluster-specific working directory; every relative path below resolves here.
setwd('/scratch/jhb362/zilinsky_2023/data/')

# Helper function for temporal transformation of one series (one county's
# rows, already sorted by date).
#
# x    numeric vector in time order.
# type one of:
#        "lag"    value n periods earlier
#        "chg"    x minus its n-period lag
#        "pctChg" proportional change versus the n-period lag
#        "mav"    trailing n-period moving average (current + n-1 lags)
#        "mavWgt" weighted average of the current value and 4 lags
#                 (weights .5/.2/.15/.1/.05; n is ignored)
# n    number of periods (rows), used by all types except "mavWgt".
# Returns a vector the same length as x, NA where there is not enough history.
# Stops on an unknown `type`. Previously it returned NULL, which dplyr
# treats as "drop this column".
temp.fun <- function(x,type = "lag",n = 1) {
  # Shift `v` back by `k` positions and pad the front with NA. This matches
  # dplyr::lag(v, k) for k >= 0. It is written in base R so it cannot be
  # masked by stats::lag, which does not shift plain vectors.
  shift_back <- function(v, k) {
    idx <- seq_along(v) - k
    idx[idx < 1] <- NA
    v[idx]
  }

  switch(type,
    lag = shift_back(x, n),
    chg = x - shift_back(x, n),
    pctChg = {
      prev <- shift_back(x, n)
      (x - prev) / prev
    },
    mav = {
      # n terms (lags 0..n-1) divided by n. The old loop started from NULL,
      # which made the sum numeric(0). It also used n+1 terms and collapsed
      # the result to a scalar with sum().
      stopifnot(length(n) == 1, !is.na(n), n >= 1)
      Reduce(`+`, lapply(seq_len(n) - 1, function(k) shift_back(x, k))) / n
    },
    mavWgt = x * .5 + shift_back(x, 1) * .2 + shift_back(x, 2) * .15 +
      shift_back(x, 3) * .1 + shift_back(x, 4) * .05,
    stop("Unknown temporal transformation type: '", type, "'", call. = FALSE)
  )
}


# Command-line interface (Rscript <script> OUTCOME PERIOD TEMP):
#   args[1]  position of the outcome in Ys (an index, not the outcome name)
#   args[2]  period unit: 'months', 'quarters', 'years' or 'all'
#   args[3]  temporal transform spec, e.g. 'lag1' (letters = type, digits = n)
# args <- c('9','months','lag1') # 9 = OBAMAPP; months, quarters, years, all
args <- commandArgs(trailingOnly = TRUE)
out <- as.character(args[1])   # NOTE(review): not used in the code shown here
temp <- as.character(args[3])
set.seed(123)                  # makes the bootstrap resamples and CV folds reproducible

# Core count from SLURM when set (echoed to the log), otherwise 1.
# NOTE(review): ncores is not used in the code shown here.
slurmTasks <- Sys.getenv("SLURM_NTASKS_PER_NODE")
cat(slurmTasks,'\n')
ncores <- if (nzchar(slurmTasks)) as.integer(slurmTasks) else 1

# Most important predictors for economic evaluations
# (this file is expected to provide `ctyDat` and `gal`, both used below)
load('./final_01_2024.RData')

# Every county column that is neither an identifier nor already a DEM_/HLTH_/
# CRIME_ covariate is tagged ECON_ (dropping any LAU_ prefix). Unwanted
# economic series are then removed.
ctyDat <- ctyDat %>%
  rename_at(vars(-matches('stcou|year|date|^DEM|^HLTH|^CRIME')),
            function(x) paste0('ECON_', gsub('LAU_', '', x))) %>%
  select(-matches('_NA_|all_ind|_wages$'), -ECON_emp, -ECON_lf, -ECON_unemp)

# Prep on the county-level data
# Dropping crime data which is crazy skewed: arrest series whose skewness is
# >= 6 (or cannot be computed) are removed along with their _lb/_ub variants.
require(moments)
skews <- vapply(ctyDat %>% select(matches('CRIME_ARRST')) %>% select(-matches('_lb|_ub')),
                function(x) skewness(x, na.rm = TRUE), numeric(1))

# Use a logical mask, not skews[-which(skews < 6)]. If every series is
# skewed, -which() is -integer(0), which selects NOTHING, so no series would
# be dropped. If no series is skewed, `drops` is empty. Either way the
# pattern below would be empty.
drops <- skews[is.na(skews) | skews >= 6]
if (length(drops) > 0) {
  ctyDat <- ctyDat %>%
    select(-matches(paste(gsub('_pct', '', names(drops)), collapse = '|')))
}


# Individual-level (Gallup) covariates
indivCovs <- c('AGECAT','PARTY','RACE','GENDER','EDUC','MARST','ANNUALINC')

# County-level covariate groups, chosen by column-name pattern
# (bounds columns _lb/_ub are excluded)
econCovs <- colnames(ctyDat %>% select(matches('^ECON_(ur|lfpr|tot_aww|.*lq_emplvl)')))
healthCovs <- colnames(ctyDat %>% select(matches('^HLTH_(Male|Female)_Total')) %>% select(-matches('_lb$|_ub$')))
crimeCovs <- colnames(ctyDat %>% select(matches('^CRIME_')) %>% select(-matches('_lb_|_ub_')))
demogCovs <- colnames(ctyDat %>% select(matches('^DEM_')) %>% select(-matches('_lb$|_ub$')))

# ZIP-level labour-force participation rate
gal$ZIPDEM_LFPR <- gal$ZIPDEM_LF / gal$ZIPDEM_totpop

zipCovs <- colnames(gal %>% select(matches('^ZIP')) %>% select(-matches('CODE|_LF$')))

ctyDat <- ctyDat %>%
  mutate(DEM_pop_tot = log(DEM_pop_tot))

# log(x + 1) each retained (skewness < 6) arrest series together with its
# _lb/_ub variants. Every base name is paired with ALL three suffixes.
# paste0(base, c(...)) recycled the suffixes, so each series got only one.
crimeBase <- gsub('_pct', '', names(skews[which(skews < 6)]))
crimeSuffixes <- c('_pct', '_lb_pct', '_ub_pct')
lgs <- paste0(rep(crimeBase, each = length(crimeSuffixes)),
              rep(crimeSuffixes, times = length(crimeBase)))
scls <- c(econCovs,healthCovs,crimeCovs,demogCovs)
fcts <- c(indivCovs)

# NOTE(review): any_of() skips suffix variants that do not exist for a given
# series; confirm every series is expected to have all three.
ctyDat <- ctyDat %>%
  mutate_at(vars(any_of(lgs)), function(x) log(x + 1))

# Temporal transformation: some covariates are monthly and the rest annual, so
#   the transformation is applied only to the monthly series (ECON_ and HLTH_
#   columns, excluding *_aww). Demographics and crime therefore influence
#   opinions through their LEVELS, while health and economics do so through
#   their CHANGES. This only matters when a temporal transformation is
#   requested on the command line.
# Exploration code kept for reference:
# test <- ctyDat %>%
#   select(stcou,year,all_of(c(econCovs,healthCovs,crimeCovs,demogCovs))) %>%
#   group_by(stcou,year) %>%
#   summarise_all(var,na.rm=T)
# 
# test2<-sapply(test,function(x) all(x == 0))
# test2[-which(test2)]
# ctyDat %>%
#   select(stcou,date,year,ECON_ur) %>%
#   distinct() %>%
#   group_by(stcou) %>%
#   arrange(date) %>%
#   mutate(ECON_ur = temp.fun(ECON_ur,type = args[5])) %>%
#   ungroup() %>%
#   filter(stcou != '') %>%
#   arrange(stcou,date)

# Split the command-line spec, e.g. 'lag1' -> type 'lag', n = 1. n counts
# rows, i.e. months for the monthly series.
tempType <- gsub('\\d+', '', temp)
tempN <- as.numeric(str_extract(temp, '\\d+'))

# Same selection as matches('ECON_|HLTH_') minus matches('_aww')
# (tidyselect's matches() is case-insensitive, hence ignore.case).
tempCols <- grep('ECON_|HLTH_', colnames(ctyDat), value = TRUE, ignore.case = TRUE)
tempCols <- tempCols[!grepl('_aww', tempCols, ignore.case = TRUE)]

# Sorting by date puts each county's rows in time order, so the lags computed
# per county group look back over that county's earlier periods.
ctyDat <- ctyDat %>%
  group_by(stcou) %>%
  arrange(date) %>%
  mutate_at(vars(all_of(tempCols)),
            function(x) temp.fun(x, type = tempType, n = tempN)) %>%
  ungroup()

# Attach county covariates to each Gallup respondent. No `by` is given, so
# dplyr joins on every column the two tables share (presumably year, stcou
# and date; confirm).
dat <- gal %>%
  rename(year = YEAR, wgt = COMB_STATEWT) %>%
  left_join(ctyDat %>% select(year,stcou,date,DEM_pop_tot,all_of(c(econCovs,healthCovs,crimeCovs,demogCovs))))

# Period key used to split the modelling loop: either one pooled period or
# calendar months/quarters/years ("months" -> unit "month", etc.).
periodUnit <- as.character(args[2])
if (periodUnit == 'all') {
  dat$time <- 'all'
} else {
  # floor_date puts each interview in the calendar period that contains it.
  # round_date would move late-month interviews into the next month (and
  # second-half interviews into the next year for 'years').
  dat <- dat %>%
    mutate(time = as.Date(lubridate::floor_date(date, unit = gsub('s$', '', periodUnit))))
}


# Permutation test + OLS
# Variable/value label table for the Gallup items (provides `lookup`).
load('./gallup/var_lookup.RData')

# Analysis name -> raw Gallup variable code.
gallupCodes <- c('ID' = 'MOTHERLODE_ID',
                 'UNION' = 'D17A',
                 'EDUC' = 'EDUCATION',
                 'RELIG' = 'D8B',
                 'ANNUALINC' = 'INCOME_SUMMARY',
                 'MONTHINC' = 'MONTHLY_INCOME',
                 'RACE' = 'RACE',
                 'SELFEMP' = 'WP10202',
                 'AGE' = 'WP1220',
                 'MARST' = 'WP1223',
                 'OCC' = 'WP1225',
                 'GENDER' = 'SC7',
                 'IDEO' = 'P20',
                 'PARTY' = 'PARTY',
                 'HLTHINS' = 'H14',
                 'SOLSATCOMP' = 'HWB17',
                 'COMMHAP' = 'HWB22',
                 'MONWORRY' = 'HWB6',
                 'ECON' = 'M30',
                 'WATCHSPEND' = 'M91',
                 'MAJORPURCH' = 'M92',
                 'CUTBACKSPEND' = 'M93',
                 'FEELGOODMON' = 'M94',
                 'MONWORRYYEST' = 'M95',
                 'MONMORETHANENOUGH' = 'M96',
                 'ENOUGHMON' = 'M97',
                 'FEELBETTERFIN' = 'M97A',
                 'TRUMPAPP' = 'P1167',
                 'TRUMPAPPLN' = 'P1167F',
                 'OBAMAPP' = 'P128',
                 'HILFAV' = 'P919A',
                 'VAGOVFAV' = 'P919W',
                 'NATECONIMP' = 'WP148',
                 'LIFETOD' = 'WP16',
                 'LIFE5YR' = 'WP18',
                 'SOLSAT' = 'WP30',
                 'HAPPY' = 'WP6878',
                 'WORRY' = 'WP69',
                 'SAD' = 'WP70',
                 'STRESS' = 'WP71',
                 'SWB' = 'WELL_BEING_INDEX')

lookup2 <- tibble(varsNEW = names(gallupCodes), vars = unname(gallupCodes))

# Keep only lookup rows for mapped variables and rename them to the analysis
# names (left join + dropping unmatched rows == inner join).
lookup <- lookup %>%
  inner_join(lookup2, by = 'vars') %>%
  select(-vars) %>%
  rename(vars = varsNEW)

# Sanity checks printed to the log.
print(filter(lookup, vars == 'SOLSATCOMP'))
print(count(filter(lookup, val_lab == ''), vars))
print(count(lookup, vars), n = 39)

# Finalise value labels:
#  - fill in the agree/disagree labels for codes 2, 3, 4 and 9 of the
#    COMMHAP/MONWORRY/SOLSATCOMP items;
#  - add labels for the derived AGECAT variable (codes 1-6);
#  - drop outcome items (and AGE, SELFEMP) from the lookup, so only variables
#    still listed here get converted to labelled factors;
#  - sanitise labels into identifier-friendly strings of at most 30 chars.
lookup <- lookup %>%
  mutate(val_lab = ifelse(vars %in% c('COMMHAP','MONWORRY','SOLSATCOMP') & vals == 2,'Somewhat disagree',
                          ifelse(vars %in% c('COMMHAP','MONWORRY','SOLSATCOMP') & vals == 3,'Neutral',
                                 ifelse(vars %in% c('COMMHAP','MONWORRY','SOLSATCOMP') & vals == 4,'Somewhat agree',
                                        ifelse(vars %in% c('COMMHAP','MONWORRY','SOLSATCOMP') & vals == 9,'REF',val_lab))))) %>%
  # bind_rows(data.frame(vars = 'COMMHAP',var_lab = 'You are proud of your community or the area where you live.',vals = 9,val_lab = 'REF')) %>%
  bind_rows(data.frame(vars = rep('AGECAT',6),
                       var_lab = 'Age category',
                       vals = 1:6,
                       val_lab = c('LT25','25-34','35-44','45-54','55-64','65Up'))) %>%
  filter(!vars %in% c('TRUMPAPP','ENOUGHMON','HILFAV','SOLSAT','WORRY','SAD','STRESS','HAPPY','SELFEMP','OBAMAPP','VAGOVFAV',
                      'ECON','NATECONIMP','LIFETOD','LIFE5YR','AGE','FEELBETTERFIN','WATCHSPEND','COMMHAP','SOLSATCOMP','MONWORRY',
                      'MONMORETHANENOUGH','MONWORRYYEST','FEELGOODMON','CUTBACKSPEND','MAJORPURCH','TRUMPAPPLN')) %>%
  # Strip ( ) , $ . [ ] and apostrophes; turn spaces, hyphens, en/em dashes,
  # slashes and (presumably) non-breaking spaces into '_'; collapse runs of
  # '_'; truncate to 30 characters.
  mutate(val_lab = substr(gsub('_{2,}','_',gsub(' |-|–|\\/|—| ','_',gsub("\\(|\\)|,|\\$|\\.|\\[|\\]|'",'',val_lab))),1,30))

# Reduce complexity of EDUC & INC: both are recoded by hand into coarse
# categories, so their detailed labels are removed from the lookup.
lookup <- lookup %>%
  filter(vars != 'EDUC',
         vars != 'ANNUALINC')


# Collapse detailed education and income codes into coarse factors.
# EDUC codes outside 1-8 (or missing) become NA. ANNUALINC codes outside
# 1-10, including missing, become 'DK_REF'.
dat <- dat %>%
  mutate(
    EDUC = factor(
      case_when(
        EDUC == 1 ~ 'LTHS',
        EDUC == 2 ~ 'HSDeg',
        EDUC %in% 3:5 ~ 'SoCo',
        EDUC %in% 6:8 ~ 'CollDegUp'
      ),
      levels = c('LTHS','HSDeg','SoCo','CollDegUp')
    ),
    ANNUALINC = factor(
      case_when(
        ANNUALINC %in% 1:4 ~ 'LT_23999',
        ANNUALINC %in% 5:6 ~ '24000_47999',
        ANNUALINC %in% 7:8 ~ '48000_89999',
        ANNUALINC %in% 9:10 ~ '90000_up',
        TRUE ~ 'DK_REF'
      ),
      levels = c('LT_23999','24000_47999','48000_89999','90000_up','DK_REF')
    )
  )

# Turn coded survey items into factors labelled from `lookup`. Levels are the
# sorted codes present in the data (factor()'s default order), and each
# level's label is matched by code. The old positional pairing
# factor(x, labels = tmp$val_lab) mislabelled levels whenever lookup rows were
# not sorted by `vals`. With a single label it silently produced
# "label1", "label2", ... instead of failing.
lookupVars <- unique(lookup$vars)
for (i in intersect(colnames(dat), lookupVars)) {
  tmp <- lookup %>%
    filter(vars == i)
  codes <- sort(unique(dat[[i]]))
  pos <- match(codes, tmp$vals)
  if (anyNA(pos)) {
    stop("No label in lookup for ", i, " code(s): ",
         paste(codes[is.na(pos)], collapse = ", "), call. = FALSE)
  }
  dat[[i]] <- factor(dat[[i]], levels = codes, labels = tmp$val_lab[pos])
}

# Independents with no lean are the reference category for PARTY.
dat <- dat %>%
  mutate(PARTY = relevel(PARTY, ref = 'Independent_no_lean'))

# Model inputs: candidate predictors for the LASSO / OLS models below.
Xs <- c(indivCovs, econCovs, healthCovs, crimeCovs, demogCovs)

# Candidate outcomes. The first command-line argument is a position in this
# vector.
Ys <- c('TRUMPAPP','ENOUGHMON','HILFAV','SOLSAT','WORRY','SAD','STRESS','HAPPY','OBAMAPP','VAGOVFAV',
        'ECON','NATECONIMP','LIFETOD','LIFE5YR','FEELBETTERFIN','WATCHSPEND','COMMHAP','SOLSATCOMP','MONWORRY',
        'MONMORETHANENOUGH','MONWORRYYEST','FEELGOODMON','CUTBACKSPEND','MAJORPURCH','ECONBIN','NATECONIMPBIN')

# COMMHAP stays numeric, with codes above 5 set to NA. Rows whose PARTY is
# 'REF' (or missing) are dropped.
dat <- dat %>%
  mutate(COMMHAP = as.numeric(COMMHAP),
         COMMHAP = ifelse(COMMHAP > 5, NA, COMMHAP)) %>%
  filter(PARTY != 'REF')


gc()
# Outcome to model. The first command-line argument is a 1-based position in
# Ys (not the outcome name), so outcome codes need not be memorised.
Y <- Ys[as.numeric(args[1])]
if (length(Y) != 1 || is.na(Y)) {
  stop("args[1] must be a position in Ys (1-", length(Ys), "), got '",
       args[1], "'", call. = FALSE)
}

# Results accumulate over periods and bootstrap draws, newest rows first.
# lambdaLookup starts as NULL so the final save() works even when no period
# has complete cases.
lassoRes <- coefRes <- lambdaLookup <- NULL
zz <- Sys.time()
for (d in unique(dat$time)) {
  # for() strips the Date class from its loop variable, so restore it.
  if (is.numeric(d)) {
    d <- as.Date(d, origin = '1970-01-01')
  }

  tmpDat <- dat %>%
    filter(time %in% d) %>%
    select(all_of(c(Y, Xs)), STATE) %>%
    drop_na()

  if (nrow(tmpDat) == 0) { next }

  for (bsInd in seq_len(100)) {
    # Resample 63% of the period's rows with replacement.
    tmpDat2 <- tmpDat %>%
      sample_n(size = round(.63 * nrow(tmpDat)), replace = TRUE)

    # Dummy-code factor/character predictors (first level dropped). STATE is
    # left out here; it enters only as the OLS fixed effect.
    tmpDatD <- fastDummies::dummy_cols(tmpDat2 %>% select(-STATE),
                                       remove_selected_columns = TRUE,
                                       remove_first_dummy = TRUE)

    # Centre and scale the design matrix (everything except the outcome).
    # Any preProcess/predict error now propagates with its own message
    # instead of a bare stop().
    Xraw <- tmpDatD %>% select(-all_of(Y))
    pp <- preProcess(Xraw, method = c("center", "scale"))
    ppX <- predict(pp, Xraw)

    # Support counts per design column. Columns with < 5 distinct values
    # (dummies) get their raw column sum, i.e. the number of 1s for a dummy.
    # Other columns get their number of distinct values.
    fcts2 <- vapply(ppX, function(x) length(unique(x)), integer(1))
    isDummy <- fcts2 < 5
    fcts3 <- names(fcts2)[isDummy]
    counts <- if (length(fcts3) > 0) {
      tmpDatD %>%
        summarise_at(vars(all_of(fcts3)), sum) %>%
        gather(vars, n)
    } else {
      tibble(vars = character(), n = integer())
    }
    # A logical mask replaces fcts2[-which(...)], which selected nothing
    # when fcts3 was empty.
    counts <- counts %>%
      bind_rows(data.frame(vars = names(fcts2)[!isDummy],
                           n = fcts2[!isDummy]))

    # Regression on optimal predictors: cross-validated LASSO path
    cvm <- cv.glmnet(x = as.matrix(ppX), y = tmpDatD[[Y]], nlambda = 30)

    m <- cvm$glmnet.fit
    # Long format: one row per (predictor, lambda step). normInd holds the
    # beta column name ("s0", "s1", ...).
    vimpTmp <- m$beta %>%
      as.matrix() %>%
      data.frame() %>%
      mutate(vars = row.names(.)) %>%
      as_tibble() %>%
      gather(normInd, coef, -vars)

    lambdaLookup <- data.frame(normInd = unique(vimpTmp$normInd),
                               lambda = m$lambda,
                               norm = apply(abs(m$beta), 2, sum),
                               dev.ratio = m$dev.ratio,
                               stringsAsFactors = FALSE)

    # NOTE: minInd/seInd are 1-based positions in m$lambda, while normInd
    # (parsed from "sK" below) is 0-based, so normInd + 1 == minInd marks
    # the lambda.min row.
    lambdaLookup$minInd <- cvm$index[1]
    lambdaLookup$seInd <- cvm$index[2]
    lambdaLookup$minLambda <- cvm$lambda.min
    lambdaLookup$seLambda <- cvm$lambda.1se

    lassoRes <- vimpTmp %>%
      filter(coef != 0) %>%
      left_join(lambdaLookup, by = 'normInd') %>%
      mutate(bsInd = bsInd,
             outcome = Y,
             normInd = as.integer(gsub('s', '', normInd)),
             period = as.character(d)) %>%
      left_join(counts, by = 'vars') %>%
      bind_rows(lassoRes)

    # Design columns with a non-zero coefficient at lambda.1se.
    beta1se <- cvm$glmnet.fit$beta[, cvm$index['1se', ]]
    keeps <- names(beta1se)[abs(beta1se) > 0]

    # Map selected design columns back to tmpDat2 columns. Factor/character
    # columns were expanded by fastDummies into "<col>_<level>" dummies, so
    # they match by that prefix. Numeric columns must match exactly. The old
    # grepl(col, keeps) treated names as regexes and matched substrings, so
    # e.g. `X` was pulled in whenever `X_2` was selected.
    isCat <- vapply(tmpDat2, function(col) is.factor(col) || is.character(col), logical(1))
    selected <- vapply(names(isCat), function(cn) {
      if (isCat[[cn]]) any(startsWith(keeps, paste0(cn, '_'))) else cn %in% keeps
    }, logical(1))
    covs <- setdiff(names(selected)[selected], Ys)
    if (length(covs) == 0) { next }

    # Post-selection OLS with state fixed effects on standardised numerics.
    m <- feols(as.formula(paste0(Y, ' ~ ', paste(covs, collapse = " + "), " | STATE")),
               tmpDat2 %>%
                 mutate_if(is.numeric, function(x) scale(x)[, 1]), warn = FALSE, verbose = 0)

    # Level counts for the individual factors, keyed like fixest coefficient
    # names (<factor><level>).
    tmpCounts <- NULL
    for (cov in fcts) {
      tmpCounts <- tmpDat2 %>%
        count(vars = get(cov)) %>%
        mutate(vars = paste0(cov, vars)) %>%
        bind_rows(tmpCounts)
    }

    coefRes <- m$coeftable %>%
      data.frame() %>%
      mutate(vars = row.names(.)) %>%
      as_tibble() %>%
      rename(est = Estimate, se = Std..Error, tstat = t.value, pval = Pr...t..) %>%
      mutate(bsInd = bsInd,
             outcome = Y,
             period = as.character(d)) %>%
      left_join(tmpCounts, by = 'vars') %>%
      bind_rows(coefRes)

  }
  cat(as.character(d),'in',round(difftime(Sys.time(),zz,units = 'mins'),2),'minutes\n')
  zz <- Sys.time()
}


# lambdaLookup is the table from the last bootstrap fit only (NULL if none ran).
save(coefRes,lassoRes,lambdaLookup,file = paste0('./results/VIMP_lasso/LASSO_2024_outcome-',Y,'_period-',args[2],'_temp-',args[3],'.RData'))
